{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 最简单的层次网络\n",
    "class simpleNet(nn.Module) :\n",
    "    def __init__(self, in_dim, hidden_dim, out_dim) :\n",
    "        super().__init__()\n",
    "        i, self.layer = 1, nn.Sequential()\n",
    "        for h_dim in hidden_dim :\n",
    "            self.layer.add_module('layer_{}'.format(i), nn.Linear(in_dim, h_dim))\n",
    "            i, in_dim = i + 1, h_dim\n",
    "        self.layer.add_module('layer_{}'.format(i), nn.Linear(in_dim, out_dim))\n",
    "        self.layerNum = i\n",
    "        \n",
    "    def forward(self, x) :\n",
    "        x = self.layer(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 添加激活函数，增加网络的非线性\n",
    "\n",
    "class Activation_Net(nn.Module) :\n",
    "    def __init__(self, in_dim, hidden_dim, out_dim) :\n",
    "        super().__init__()\n",
    "        i, self.layer = 1, nn.Sequential()\n",
    "        for h_dim in hidden_dim :\n",
    "            self.layer.add_module('layer_{}'.format(i), nn.Sequential(nn.Linear(in_dim, h_dim), nn.ReLU(True)))\n",
    "            i, in_dim = i + 1, h_dim\n",
    "        self.layer.add_module('layer_{}'.format(i), nn.Sequential(nn.Linear(in_dim, out_dim)))\n",
    "        self.layerNum = i\n",
    "    def forward(self, x) :\n",
    "        x = self.layer(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 添加批标准化的网络\n",
    "\n",
    "class Batch_net(nn.Module) :\n",
    "    def __init__(self, in_dim, hidden_dim, out_dim) :\n",
    "        super().__init__()\n",
    "        i, self.layer = 1, nn.Sequential()\n",
    "        for h_dim in hidden_dim :\n",
    "            self.layer.add_module('layer_{}'.format(i), \n",
    "                                  nn.Sequential(nn.Linear(in_dim, h_dim), nn.BatchNorm1d(h_dim), nn.ReLU(True)))\n",
    "            i, in_dim = i + 1, h_dim\n",
    "        self.layer.add_module('layer_{}'.format(i), nn.Sequential(nn.Linear(in_dim, out_dim)))\n",
    "        self.layerNum = i\n",
    "    def forward(self, x) :\n",
    "        x = self.layer(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 确定超参数\n",
    "\n",
    "epoch_size = 700\n",
    "learning_rate = 1e-2\n",
    "num_epoches = 60"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "这里使用的手写数字数据集是从 Kaggle 上下载的，[下载链接](https://www.kaggle.com/c/digit-recognizer/data)。\n",
    "\n",
    "将数据保存到名为 `data` 的文件夹中（代码通过相对路径 `./data/` 读取），文件结构如下：\n",
    "```\n",
    "data\n",
    "├── sample_submission.csv\n",
    "├── test.csv\n",
    "└── train.csv\n",
    "\n",
    "0 directories, 3 files\n",
    "```\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "42000 784\n"
     ]
    }
   ],
   "source": [
    "# Load the training data from train.csv\n",
    "\n",
    "import csv\n",
    "# the csv module expects files it parses to be opened with newline=''\n",
    "with open('./data/train.csv', newline='') as f :\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the header row once instead of testing every row\n",
    "    label, attr = [], []\n",
    "    for row in reader :\n",
    "        # first column is the class label, the remaining columns are the features\n",
    "        label.append(int(row[0]))\n",
    "        attr.append([float(value) for value in row[1:]])\n",
    "# number of samples and number of features per sample\n",
    "print(len(label), len(attr[0]) if attr else 0)\n",
    "\n",
    "# Cut the data into consecutive chunks of `epoch_size` rows; each entry of\n",
    "# `epoches` is a (features, labels) pair and the last one may be shorter\n",
    "epoches = []\n",
    "for start in range(0, len(label), epoch_size) :\n",
    "    stop = start + epoch_size\n",
    "    torch_attr = torch.FloatTensor(attr[start : stop])\n",
    "    torch_label = torch.LongTensor(label[start : stop])\n",
    "    epoches.append((torch_attr, torch_label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the model (Activation_Net): 28 * 28 inputs, hidden widths 300 and 100, 10 outputs\n",
    "net = Activation_Net(28 * 28, [300, 100], 10)\n",
    "# move the parameters to the GPU when one is available\n",
    "if torch.cuda.is_available() :\n",
    "    net = net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function: cross entropy between the network output and the integer class labels\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "# Optimizer: plain SGD over all parameters of `net`\n",
    "optimizer = optim.SGD(params=net.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- train time 1 ---\n",
      "average loss = 11.5832\n",
      "average correct number = 0.4225\n",
      "--- train time 2 ---\n",
      "average loss = 0.5315\n",
      "average correct number = 0.8423\n",
      "--- train time 3 ---\n",
      "average loss = 0.3385\n",
      "average correct number = 0.8999\n",
      "--- train time 4 ---\n",
      "average loss = 0.2668\n",
      "average correct number = 0.9208\n",
      "--- train time 5 ---\n",
      "average loss = 0.2259\n",
      "average correct number = 0.9326\n",
      "--- train time 6 ---\n",
      "average loss = 0.2012\n",
      "average correct number = 0.9394\n",
      "--- train time 7 ---\n",
      "average loss = 0.1813\n",
      "average correct number = 0.9449\n",
      "--- train time 8 ---\n",
      "average loss = 0.1658\n",
      "average correct number = 0.9500\n",
      "--- train time 9 ---\n",
      "average loss = 0.1531\n",
      "average correct number = 0.9541\n",
      "--- train time 10 ---\n",
      "average loss = 0.1429\n",
      "average correct number = 0.9567\n",
      "--- train time 11 ---\n",
      "average loss = 0.1339\n",
      "average correct number = 0.9595\n",
      "--- train time 12 ---\n",
      "average loss = 0.1257\n",
      "average correct number = 0.9615\n",
      "--- train time 13 ---\n",
      "average loss = 0.1180\n",
      "average correct number = 0.9641\n",
      "--- train time 14 ---\n",
      "average loss = 0.1121\n",
      "average correct number = 0.9661\n",
      "--- train time 15 ---\n",
      "average loss = 0.1066\n",
      "average correct number = 0.9678\n",
      "--- train time 16 ---\n",
      "average loss = 0.1019\n",
      "average correct number = 0.9692\n",
      "--- train time 17 ---\n",
      "average loss = 0.0966\n",
      "average correct number = 0.9708\n",
      "--- train time 18 ---\n",
      "average loss = 0.0925\n",
      "average correct number = 0.9723\n",
      "--- train time 19 ---\n",
      "average loss = 0.0883\n",
      "average correct number = 0.9735\n",
      "--- train time 20 ---\n",
      "average loss = 0.0843\n",
      "average correct number = 0.9747\n",
      "--- train time 40 ---\n",
      "average loss = 0.0390\n",
      "average correct number = 0.9904\n",
      "--- train time 60 ---\n",
      "average loss = 0.0206\n",
      "average correct number = 0.9961\n",
      "--- train time 80 ---\n",
      "average loss = 0.0112\n",
      "average correct number = 0.9985\n",
      "--- train time 100 ---\n",
      "average loss = 0.0065\n",
      "average correct number = 0.9995\n",
      "--- train time 120 ---\n",
      "average loss = 0.0039\n",
      "average correct number = 0.9998\n",
      "--- train time 140 ---\n",
      "average loss = 0.0026\n",
      "average correct number = 0.9999\n",
      "--- train time 160 ---\n",
      "average loss = 0.0019\n",
      "average correct number = 1.0000\n",
      "--- train time 180 ---\n",
      "average loss = 0.0014\n",
      "average correct number = 1.0000\n",
      "--- train time 200 ---\n",
      "average loss = 0.0011\n",
      "average correct number = 1.0000\n",
      "--- train time 220 ---\n",
      "average loss = 0.0009\n",
      "average correct number = 1.0000\n",
      "--- train time 240 ---\n",
      "average loss = 0.0008\n",
      "average correct number = 1.0000\n",
      "--- train time 260 ---\n",
      "average loss = 0.0007\n",
      "average correct number = 1.0000\n",
      "--- train time 280 ---\n",
      "average loss = 0.0006\n",
      "average correct number = 1.0000\n",
      "--- train time 300 ---\n",
      "average loss = 0.0005\n",
      "average correct number = 1.0000\n"
     ]
    }
   ],
   "source": [
    "# Training procedure\n",
    "def train() :\n",
    "    '''Run one pass over every chunk in `epoches`, updating `net` after each chunk.\n",
    "\n",
    "    Reads the notebook-level `epoches`, `net`, `criterion` and `optimizer`.\n",
    "\n",
    "    Returns:\n",
    "        (loss_avg, cort_num_avg): the loss averaged over the chunks and the\n",
    "        fraction of samples whose arg-max prediction matched the label.\n",
    "        Both are NaN when `epoches` is empty.\n",
    "    '''\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "    batch_num, sample_num, loss_sum, cort_num_sum = 0, 0, 0.0, 0\n",
    "    for batch_attr, batch_label in epoches :\n",
    "        inputs, target = Variable(batch_attr), Variable(batch_label)\n",
    "        if use_cuda :\n",
    "            inputs, target = inputs.cuda(), target.cuda()\n",
    "        output = net(inputs)\n",
    "        loss = criterion(output, target)\n",
    "        # reset gradients\n",
    "        optimizer.zero_grad()\n",
    "        # backward pass\n",
    "        loss.backward()\n",
    "        # update parameters\n",
    "        optimizer.step()\n",
    "\n",
    "        # get training information\n",
    "        # newer PyTorch returns a 0-dim loss that cannot be indexed; fall back\n",
    "        # to the old indexing form only when .item() is not available\n",
    "        loss_sum += loss.item() if hasattr(loss, 'item') else loss.data[0]\n",
    "        _, pred = torch.max(output.data, 1)\n",
    "        # compare with `target`, which lives on the same device as `pred`\n",
    "        # (the raw chunk label stays on the CPU even when the model is on the GPU)\n",
    "        cort_num_sum += int(torch.eq(pred, target.data).sum())\n",
    "        batch_num += 1\n",
    "        # count real samples so a shorter final chunk does not skew the accuracy\n",
    "        sample_num += batch_label.size(0)\n",
    "\n",
    "    if batch_num == 0 :\n",
    "        return float('nan'), float('nan')\n",
    "    loss_avg = loss_sum / batch_num\n",
    "    cort_num_avg = cort_num_sum / sample_num\n",
    "    return loss_avg, cort_num_avg\n",
    "\n",
    "# Run the model over the whole data set 300 times\n",
    "training_time = 300\n",
    "loss, correct = [], []\n",
    "for run in range(1, training_time + 1) :\n",
    "    loss_avg, correct_num_avg = train()\n",
    "    loss.append(loss_avg)\n",
    "    correct.append(correct_num_avg)\n",
    "    # report each of the first 19 runs, then every 20th run\n",
    "    should_report = run < 20 or run % 20 == 0\n",
    "    if should_report :\n",
    "        print('--- train time {} ---'.format(run))\n",
    "        print('average loss = {:.4f}'.format(loss_avg))\n",
    "        print('average correct number = {:.4f}'.format(correct_num_avg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAFz9JREFUeJzt3XuQHWWdxvHnOTOTCSaRW0YkAkKUxaV0UXYWQS3dEryxrmiVVuF6X7eo8rZeS3G1vGyp5WW11BV1I7LiqqiLulLeFlBYtUrRCaJcIhIwaoAkgwgkICEz57d/9HtmTibT3TPnnMk57/D9VE2dPn16un+dhuf0vO/b3Y4IAQDy1+h3AQCA3iDQAWCZINABYJkg0AFgmSDQAWCZINABYJkg0LFkbG+xfVq/65Ak2++xfZvtbUu4jV221/d6WWChCHQse7aPkvRGScdHxIPn+fxvbW/tdjsRsToibur1ssBCEei4PzhK0h8jYkenK7A93MN6gCVBoGO/sD1q+6O2b0k/H7U9mj5ba/tbtu+wfbvtH9lupM/eYvtm2zttX2/71JL1H2j787Ynbf/O9tttN1KTzyWS1qVmjs/N+b1Vkr7b9vku2+tsv8v2hba/YPsuSS+1fZLtn6Q6b7X9Cdsr2tYVth+epj9n+xzb3061X2H7YR0u+9S073fa/qTt/7P9T705MlhOCHTsL2+TdLKkR0s6QdJJkt6ePnujpK2SxiQdJulfJIXt4yS9WtLfRMQaSU+TtKVk/f8u6UBJ6yU9SdKLJb0sIi6V9AxJt6Rmjpe2/1JE3D3n89URcUv6+AxJF0o6SNIXJU1Ler2ktZJOkXSqpFdW7POZkt4t6WBJmyW9d7HL2l6banirpEMlXS/pcRXrwf0YgY795QWS/jUidkTEpIrwelH6bI+kwyU9NCL2RMSPorjJ0LSkUUnH2x6JiC0RcePcFdseUhGIb42InRGxRdKH29bfqZ9ExP9ERDMi/hwRGyPipxExlbbxHyq+PMp8IyJ+FhFTKr4QHt3BsqdLujYivp4++7ikJevYRd4IdOwv6yT9ru3979I8SfqQirPSi23fZPtsSYqIzZJeJ+ldknbY/rLtddrXWkkj86z/IV3W/If2N7b/IjUNbUvNMO9L2y7THrz3SFrdwbLr2utIX3Rdd+BieSLQsb/cIumhbe+PSvOUzqrfGBHrJT1L0htabeUR8aWIeEL63ZD0gXnWfZuKs/y56795gbWV3XJ07vxPSfq1pGMj4oEqmoa8wG106lZJR7Te2Hb7e6AdgY795QJJb7c9ltqF3yHpC5Jk+5m2H57C6k4VTS1N28fZfnLqPL1X0p8lNeeuOCKmJX1V0nttr7H9UElvaK1/AbZLOtT2gTXLrZF0l6Rdth8h6RULXH83vi3pUbafnUbavErSPkMvAYlAx/7zHkkTkn4l6WpJV6Z5knSspEsl7ZL0E0mfjIjLVLSfv1/FGfg2SQ9S0Tk4n9dIulvSTZJ+LOlLks5bSGER8WsVXzg3pREs8zXrSNKbJP2DpJ2SPiPpKwtZfzci4jZJz5P0QUl/lHS8in/H3Uu9beTHPOACyEcazrlV0gvSlx4wgzN0YMDZfprtg1LTU6vd/qd9LgsDiEAHBt8pkm5U0fT095KeHRF/7m9JGEQ0uQDAMsEZOgAsE7U3HLJ9nqRnStoREY9M8z6k4k+/+1T8KfiyiLijbl1r166No48+uquCAeD+ZuPGjbdFxFjdcrVNLrafqGI42efbAv2pkn4QEVO2PyBJEfGWuo2Nj4/HxMTEQuoHACS2N0bEeN1ytU0uEfFDSbfPmXdxuq+EVPS2c+UaAPRZL9rQ/1HF7UfnZfss2xO2JyYnJ3uwOQDAfLoKdNtvk9S6O9y8ImJDRIxHxPjYWG0TEACgQx0/hcX2S1V0lp4ajH0EgL7rKNBtP13SmyU9KSLu6W1JAIBO1Da52L5A
xQ2TjrO91fbLJX1CxZ3nLrF9le1PL3GdAIAatWfoEfH8eWZ/dglqAQB0IYsrRb+/abs+efnmfpcBAAMti0C//PpJnfuj3/a7DAAYaFkEui01GUgDAJXyCHRJ5DkAVMsj0G0x1B0AqmUS6JyhA0CdPAJdFnkOANWyCPSGRZMLANTIItCLUS79rgIABlsWgd6wFTS6AEClLAJdnKEDQK0sAt2yOEEHgGpZBHrDoskFAGpkEeh0igJAvSwCvcGVogBQK4tAtzhDB4A6WQS67H5XAAADL4tAb6Q8p9kFAMplEehWkeg0uwBAuTwCnTN0AKiVRaDPNLn0twwAGGhZBLrdanIh0gGgTCaBXryS5wBQLo9AT52iBDoAlMsj0Gfa0El0AChTG+i2z7O9w/Y1bfMOsX2J7RvS68FLWiRNLgBQayFn6J+T9PQ5886W9P2IOFbS99P7JTM7Dp1EB4AytYEeET+UdPuc2WdIOj9Nny/p2T2uay9m2CIA1Oq0Df2wiLg1TW+TdFjZgrbPsj1he2JycrKjjbWGLXKCDgDluu4UjeLyzdKojYgNETEeEeNjY2MdbaN1ay6uFAWAcp0G+nbbh0tSet3Ru5L2RacoANTrNNAvkvSSNP0SSd/sTTnz40pRAKi3kGGLF0j6iaTjbG+1/XJJ75f0FNs3SDotvV8ydIoCQL3hugUi4vklH53a41pK0SkKAPXyuFI0vdIpCgDlsgj0RusMvc91AMAgyyLQudsiANTLI9DTK6NcAKBcFoFOkwsA1Msi0Fun6E2eEg0ApbII9NYZOgCgXBaBPjtssa9lAMBAyyPQW00uJDoAlMoi0OkUBYB6WQQ6Z+gAUC+LQG8hzwGgXBaBPjvKhUQHgDJZBPpsk0t/6wCAQZZFoDe4fS4A1Moi0GcbXEh0ACiTR6DPXPrf3zoAYJBlEuitceicoQNAmTwCPb3Shg4A5bIIdDpFAaBeFoE+88QimlwAoFRWgc44dAAol0mgt5pcSHQAKJNHoKdX4hwAyuUR6JyhA0CtrgLd9uttX2v7GtsX2F7Zq8LaNVqdouQ5AJTqONBtP0TSP0saj4hHShqSdGavCttrW6nRhU5RACjXbZPLsKQDbA9LeoCkW7ovaV+zZ+gkOgCU6TjQI+JmSf8m6feSbpV0Z0RcPHc522fZnrA9MTk52dnGZsahAwDKdNPkcrCkMyQdI2mdpFW2Xzh3uYjYEBHjETE+NjbW2bZmmlyIdAAo002Ty2mSfhsRkxGxR9LXJT2uN2XtrcG4RQCo1U2g/17SybYf4GJc4amSNvWmrL3N3m0RAFCmmzb0KyRdKOlKSVendW3oUV17acxc+k+kA0CZ4W5+OSLeKemdPaqllBmHDgC1srhSVHSKAkCtLAK9wbBFAKiVRaDbJDoA1Mkj0NMrTS4AUC6LQOcRdABQL4tAN8MWAaBWVoFOnANAuTwCXTS5AECdPAKd2+cCQK0sAr3BvVwAoFYWgc6l/wBQL4tA5+ZcAFAvi0BvXVpEnANAuSwCnU5RAKiXRaBzpSgA1Msi0GefQEeiA0CZPAK91Sna7G8dADDIsgh0xqEDQL0sAr2FTlEAKJdFoDcadIoCQJ0sAp1OUQCol0egz1wp2t86AGCQZRHojEMHgHpZBDpNLgBQL4tAF00uAFCrq0C3fZDtC23/2vYm26f0qrB2De6fCwC1hrv8/Y9J+l5EPNf2CkkP6EFN+5htcgEAlOk40G0fKOmJkl4qSRFxn6T7elPW3lpn6E3aXACgVDdNLsdImpT0n7Z/Yftc26t6VNdeZlpclmLlALBMdBPow5JOlPSpiHiMpLslnT13Idtn2Z6wPTE5OdnRhpwaXThBB4By3QT6VklbI+KK9P5CFQG/l4jYEBHjETE+NjbW0YbcmFlXZ5UCwP1Ax4EeEdsk/cH2cWnWqZKu60lVc7h+EQC43+t2lMtrJH0xjXC5SdLLui9pXzOdopyhA0CprgI9Iq6SNN6jWkoxDB0A6mVx
pWirU5Q8B4ByeQT6zKX/RDoAlMkq0MlzACiXR6C3mlxIdAAolUWgNzhDB4BaWQS6TacoANTJItAbdIoCQK0sAt08gg4AamUR6C10igJAuWwCvWHa0AGgSjaBbpsmFwCokE2gN0ynKABUySbQLdPkAgAVsgl0mVEuAFAlm0BvmFEuAFAlm0CnyQUAquUT6JaaPCUaAEplE+gNc4YOAFWyCXSLTlEAqJJPoDMOHQAqZRTo7ncJADDQMgp0hi0CQJVsAr1hi0EuAFAum0C3pGCcCwCUyifQOUMHgEoZBTrDFgGgSteBbnvI9i9sf6sXBZVuRxKPuACAcr04Q3+tpE09WE+lhq1mc6m3AgD56irQbR8h6e8kndubcqq2RacoAFTp9gz9o5LeLKn03Nn2WbYnbE9MTk52vCEu/QeAah0Huu1nStoRERurlouIDRExHhHjY2NjnW6OUS4AUKObM/THS3qW7S2Svizpyba/0JOq5kGTCwBU6zjQI+KtEXFERBwt6UxJP4iIF/assjkaNk0uAFAhs3HoJDoAlBnuxUoi4nJJl/diXWWKS/8BAGWyOUPn5lwAUC2bQBdNLgBQKZtAp1MUAKplE+jcPhcAquUT6NxtEQAqZRPoRacoiQ4AZbIJdIkzdACokk2g26YFHQAqZBPoDYYtAkClbAKdTlEAqJZNoNMpCgDVsgl07uUCANWyCXRxpSgAVMom0BsWTS4AUCGbQHe/CwCAAZdNoHNzLgColk2gmyYXAKiUT6CLM3QAqJJPoHOGDgCVsgp04hwAyuUT6CLRAaBKNoHeaNDkAgBVsgl0i9vnAkCVfAKd2+cCQKWMAt1qkucAUKrjQLd9pO3LbF9n+1rbr+1lYftsT/SJAkCV4S5+d0rSGyPiSttrJG20fUlEXNej2vbCE4sAoFrHZ+gRcWtEXJmmd0raJOkhvSpsLnMvFwCo1JM2dNtHS3qMpCvm+ews2xO2JyYnJzvfhqSg0QUASnUd6LZXS/qapNdFxF1zP4+IDRExHhHjY2Nj3WxHzWYXhQLAMtdVoNseURHmX4yIr/empLJt0SkKAFW6GeViSZ+VtCkiPtK7kuZHpygAVOvmDP3xkl4k6cm2r0o/p/eorn1w+1wAqNbxsMWI+LH245PhiiYXEh0AymRzpWiDK0UBoFI2gS7a0AGgUjaB3jB3WwSAKtkEuiU6RQGgQj6BTpMLAFTKJtDpFAWAatkEOvdyAYBq2QR6Mcql30UAwODKJtAb3D4XACplE+jFKBcSHQDKZBPojEMHgGrZBLotNTlDB4BSWQU6eQ4A5TIKdMahA0CVfAJdEs8sAoBy2QQ6wxYBoFo2gU6nKABUyyfQRYMLAFTJJ9BtTU8HFxcBQIlsAv3Yw1Zr5+4p3bBjV79LAYCBlE2gn/aXh0mSLrlue58rAYDBlE2gH/bAlTrhyIP0vWu20ewCAPPIJtAl6bl/fYSuvvlOnXPZZkIdAOYY7ncBi/HCxx6liS23698u/o2u+sOdetnjj9Yp6w9Vo+F+lwYAfZdVoNvWh593go578Bp96vIbdemm7XrQmlE9dv2hOvGog/SwsdU6Zu0qrTvoAA0R8gDuZ9xN04Xtp0v6mKQhSedGxPurlh8fH4+JiYmOt9fu3j3Tuvi67br42m36+Zbbtf2u3TOfrRhu6MiDD9Da1aM6dPUKHbJqhQ5ZNapDVxXTa1YOa83KYa0eHdHqlcNaPVr88CUAYBDZ3hgR43XLdXyGbntI0jmSniJpq6Sf274oIq7rdJ2LsXJkSM86YZ2edcI6RYQmd+7WTbfdrS233a3f3na3fn/7Pfrjrvt0/baduv3u+/Sne/bUrvOAkSEdsGJIo8MNrRwpXkdHhrSy7f3KkSGtHGlodHhII0MNjQxZw0PWcKM13dBwwxoZamh4yBppFK/DQw2NNGY/bzSshotbGhQ/0lDDctt0w5bbplvLNey07DzTreUabetO01bxV46kND37HkD+umlyOUnS5oi4
SZJsf1nSGZL2S6C3s60HPXClHvTAlTp5/aHzLjM13dSf7tmjP91zn3beO6Vdu6e0694p3b17SjvT9K7de3Tvnqbu3TOt3VPF673p9Y4/79HuPdPFvD1N3Ts1rT1TTU01Q1PN0PQyuBVkK9tbwe+2+ZZbd0ib/TJIM2zt82XRtnixrrnv27Ynec629133wvdhab6gFrPaRS27iL1b3HoXsewiVryof90BqHeQvO85j9JJxxyypNvoJtAfIukPbe+3Snrs3IVsnyXpLEk66qijuthcd4aHGhpbM6qxNaNLsv5mCvapZlN7pkNT00XY75luamq6fX4x3YzikXrTzVAzivvUFD/FulrT083i6tjp9H6f30nT01Es12yGptNyzQhNN2eXa33lREihmLnZWbRmpunW57PLqm3ZYsbsumJmmfnWPbNMxbZnl03ratvWYloEF/OVurj1Lk0Ri6t34Usv3b/DYta7NPXmfP+PVaNDS76NJe8UjYgNkjZIRRv6Um+vXxoNa0XDWpHXSFAAy0g36XOzpCPb3h+R5gEA+qCbQP+5pGNtH2N7haQzJV3Um7IAAIvVcZNLREzZfrWk/1UxbPG8iLi2Z5UBABalqzb0iPiOpO/0qBYAQBfowQOAZYJAB4BlgkAHgGWCQAeAZaKrm3MtemP2pKTfdfjrayXd1sNy+ol9GUzsy2BiX6SHRsRY3UL7NdC7YXtiIXcbywH7MpjYl8HEviwcTS4AsEwQ6ACwTOQU6Bv6XUAPsS+DiX0ZTOzLAmXThg4AqJbTGToAoAKBDgDLRBaBbvvptq+3vdn22f2uZ7Fsb7F9te2rbE+keYfYvsT2Den14H7XOR/b59neYfuatnnz1u7Cx9Nx+pXtE/tX+d5K9uNdtm9Ox+Uq26e3ffbWtB/X235af6qen+0jbV9m+zrb19p+bZqf43Ep25fsjo3tlbZ/ZvuXaV/eneYfY/uKVPNX0u3GZXs0vd+cPj+66yIiPbpsUH9U3Jr3RknrJa2Q9EtJx/e7rkXuwxZJa+fM+6Cks9P02ZI+0O86S2p/oqQTJV1TV7uk0yV9V8UjIk+WdEW/66/Zj3dJetM8yx6f/jsblXRM+u9vqN/70Fbf4ZJOTNNrJP0m1ZzjcSnbl+yOTfr3XZ2mRyRdkf69vyrpzDT/05JekaZfKenTafpMSV/ptoYcztBnHkYdEfdJaj2MOndnSDo/TZ8v6dl9rKVURPxQ0u1zZpfVfoakz0fhp5IOsn34/qm0Wsl+lDlD0pcjYndE/FbSZhX/HQ6EiLg1Iq5M0zslbVLxjN8cj0vZvpQZ2GOT/n13pbcj6SckPVnShWn+3OPSOl4XSjrVXT4BO4dAn+9h1FUHfBCFpIttb0wPzZakwyLi1jS9TdJh/SmtI2W153isXp2aIc5ra/bKZj/Sn+mPUXE2mPVxmbMvUobHxvaQ7ask7ZB0iYq/IO6IiKm0SHu9M/uSPr9T0qHdbD+HQF8OnhARJ0p6hqRX2X5i+4dR/M2V5fjRnGuX9ClJD5P0aEm3Svpwf8tZHNurJX1N0usi4q72z3I7LvPsS5bHJiKmI+LRKp6xfJKkR+zP7ecQ6Nk/jDoibk6vOyR9Q8WB3t76sze97uhfhYtWVntWxyoitqf/AZuSPqPZP90Hfj9sj6gIwC9GxNfT7CyPy3z7kvOxkaSIuEPSZZJOUdHE1Xo6XHu9M/uSPj9Q0h+72W4OgZ71w6htr7K9pjUt6amSrlGxDy9Ji71E0jf7U2FHymq/SNKL06iKkyXd2dYEMHDmtCM/R8VxkYr9ODONQjhG0rGSfra/6yuT2lk/K2lTRHyk7aPsjkvZvuR4bGyP2T4oTR8g6Skq+gQuk/TctNjc49I6Xs+V9IP0l1Xn+t0zvMDe49NV9H7fKOlt/a5nkbWvV9Er/0tJ17bqV9FW9n1JN0i6VNIh/a61pP4LVPzJu0dF+9/Ly2pX0ct/TjpOV0sa73f9NfvxX6nOX6X/uQ5vW/5t
aT+ul/SMftc/Z1+eoKI55VeSrko/p2d6XMr2JbtjI+mvJP0i1XyNpHek+etVfOlslvTfkkbT/JXp/eb0+fpua+DSfwBYJnJocgEALACBDgDLBIEOAMsEgQ4AywSBDgDLBIEOAMsEgQ4Ay8T/A9dRuQ0FDK7OAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f1b94ee1470>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAHDtJREFUeJzt3XuYXHWd5/H3p7vTSUhCuKRBSEIuTATiAIoRURnBYRguPhJn0Nkwz8zqemEcxZlxvSyMLsuwj7s6z844uqIOjqzghcCwOmZn4mRVYHjWGzQCkYDBNgjpJkqTG+TWt/ruH+d0UulUnap0V6f6V3xez1NPVZ3zq3O+p0/y6V//zqlzFBGYmVlraWt2AWZm1ngOdzOzFuRwNzNrQQ53M7MW5HA3M2tBDnczsxbkcDdrIEmnSXpY0guS/myS1vGXkv6h0W2ttcjnudtUJelC4KsRsaCBy1wMPAlMi4jhRi23bPlfAp6PiA9UmX8v2TY5cG1SueduDSOpo55pk1xD+5FcXwWLgA3j/fCR/nlZ63K4GwCSFkr6hqR+SVslfTaf3ibpY5KekvSspNskzc3nLZYUkt4p6Wng7krT8rbnSfqBpB2SHsl75aPrPk7S/5L0jKTtkv5J0izg28DJknblj5Mr1P1lSZ+XtFbSbuANkt4o6SFJz0vaLOmGso/clz/vyJf5mnw575D0eL7+dZIWFfysrpC0Id+WeyWdkU+/G3gD8Nl82S8d87mPA79VNn/0ZxyS3ifp58DP82mfzmt/XtKDkn6rbDk3SPrqmH3wNklPS3pO0kfH2XampFvzn8Hjkj4iqbfaz8GmuIjw40X+ANqBR4BPAbOAGcD5+bx3AD3AUmA28A3gK/m8xUAAt+Wfm1ll2nxgK3A5WYfi4vx9V76cfwHuAI4FpgEX5NMvBHpr1P5lYCfwunzZM/LPnZm/Pwv4NfDmMTV3lC1jZb6NZwAdwMeAH1RZ30uB3fk2TAM+kn+2M59/L/CugnoPmZ/X8x3gOGBmPu2PgOPzej4I/AqYkc+7gWxop3x7vpj/rM8GBoAzxtH2E8C/5fthAbC+1s/fj6n7aHoBfjT/AbwG6C8PvLJ53wPeW/b+NGAoD53RsFhaNr/StP9E/guhbNo64G3ASUAJOLbCuusN99tqtPk74FNj6isP928D7yx73wbsARZVWNZ/Bu4c07YPuDB/P95w/+0a27AdODt/XSmwF5S1vR9YNY62m4BLyua9y+Ge7sPDMgawEHgqKh9gPBl4quz9U2TBfmLZtM0VPlc+bRHw1nwYY4ekHcD5ZMG+ENgWEdsnUP9B65f0akn35ENMO4H3APMKPr8I+HRZbdsAkf3FMdZBP4+IKOXrr9R2ItvwoXxoZGde09wa2/Crstd7yP7KOty2J4+po9J+tUQ43A2y/8SnVDmY9wxZ+I06BRgmG+oYVemUq/Jpm8l67seUPWZFxCfyecdJOqbGMoqMbfd1YA2wMCLmAl8gC+tqy9wM/MmY+mZGxA8qtD3o5yFJZL+g+sZZ6yHT8/H1jwB/QPYXzTFkQ0+q8tlG2UI2HDNq4SSvzyaRw90g+9N8C/AJSbMkzZD0unze7cAHJC2RNBv4b8AdVXr51XwVeJOkSyS158u/UNKCiNhCNizyOUnHSpom6fX5534NHD96APcwzCH7a2CfpHOBPyyb1082DLS0bNoXgOskvQxA0lxJb62y7DuBN0q6SNI0svHwAaDSL4JKfj1m3dXqH85r7ZB0PXB0ncufiDvJfg7HSpoPXHME1mmTxOFuRMQI8CbgN4CngV7g3+WzbwG+QnaWyZPAPuD9h7n8zWQHLf+SLLA2Ax/mwL+/PyYbx/8Z8CzwF/nnfkb2y2VTPmRyyNkyVbwXuFHSC8D1ZKE1Wsse4OPA9/NlnhcR3wQ+CayW9DzwKHBZlW3ZSHaw838Cz5H93N4UEYN11vZp4C35GSmf
qdJmHfCvwBNkQ0D7ODJDJDeS7fsnge8Cd5H94rIE+UtMZlaRpD8lO9h6QbNrscPnnruZASDpJEmvU/bdhtPIhpy+2ey6bHz8bTgzG9UJ/D2wBNgBrAY+19SKbNw8LGNm1oI8LGNm1oKaNiwzb968WLx4cbNWb2aWpAcffPC5iOiq1a5p4b548WK6u7ubtXozsyRJeqp2Kw/LmJm1JIe7mVkLcribmbUgh7uZWQtyuJuZtaCa4S7pFmW3V3u0ynxJ+oykHknrJZ3T+DLNzOxw1NNz/zJwacH8y4Bl+eNq4PMTL8vMzCai5nnuEXGfpMUFTVaS3eYsgB9JOkbSSfl1um0KigiGS8HQSImhkdHnEkPDwVDpwOuR/HZdpfxOcKWAyG/NWAoIIn8PpQiC7JnR92XTRz832mZ0Hge9H73tY7bs/evLXzO63v3rP9D+oPVP6GczoY9PYL3Nq3uimzyxdSe63RPcXxedcSJnL6x0f5rGacSXmOZz8LWme/Nph4S7pKvJeveccsopDVh16xoeKbFrYJjhUrBt9yADQyWe2z3Atl2D7BkcZu/QCHsGR9g7NMK+wQOv9+bPewZH2DdU9npwhKFSiZFSMDTi6wmZTZQmcF+sE46ekUS41y0ibgZuBlixYsWLLmEigp17h+h/YSB77Dr4+bldg2zd/3og7zEX6+xo46jOdmZOa2fm6PO0dubM6OCEOdOzeZ3tTO9oZ3pHG21tYlp7G53toqO97ZDX09qz+R1toqNdSEJAm4QEQrQJUD4NaGvLnpW3aRvzGcpeH3gG8mVJ+TP5OsZMO2R9o7UUrG8iJrIMTeBOeBOtfSIf1wRXPrF1T2jVE669VTUi3Ps4+F6LC6j/fpItZWB4hL7te9m8fS+bt+1h665B+nft46mte3h62x627NjH4EjpkM91trcxb3Yn8+ZM58SjZ/Cyk4/mJUfPYO5RnbQJuuZMZ3pHO8fP7uT4WZ0c1dnBUZ3tzJjWTnub/2Gb2aEaEe5rgGskrQZeDex8MYy3v7BviEf7nuenfTtY37uTn/bt5Oltew4Zips7cxqLjj+KM+fP5dLffAknzJlB15zpdM2eTtecTrpmz+DomR3ufZhZQ9UMd0m3AxcC8yT1Av8FmAYQEV8A1gKXAz3AHuA/TFaxzfTMjr1872fP8pOntrO+dwebntu9P8jnHzOTsxbM5c0vn88pxx3FwuOOYuFxM+maPZ2Odn+VwMyOvHrOlrmqxvwA3tewiqaI4ZESD2/ewT0bn+Wen/Xz2JbngWyI5OwFc7ni7PmctXAuZ86fy7zZ05tcrZnZwXybvTL7hkZYt+FXfPfxZ7nviX527h2ivU28ctGxXHvZ6fzOGSdyatcsD6GY2ZTncAc2/uoF1jzSx+r7N7N19yDzZk/n4uUn8obTTuD8ZfOYO3Nas0s0MzssL9pw37Z7kG893Mc/dvfy2JbnkeCi00/g7a9dwmtPPZ42n4ViZgl70YV7/wsDfO7eHr72o6cZHClx5vy53PCm5bzxrJPpmuOxczNrDS+acN+5d4gv3reJW77/JAPDJa48Zz7vOH8Jp7/k6GaXZmbWcC0f7hHBN37Sx43//Bg79w7xprNP5gO/s4ylXbObXZqZ2aRp6XAfGilx/bc2cPv9T/OqxcdywxUv42Unz212WWZmk65lw33HnkHe+7Wf8INfbOU9F5zKhy85zV/VN7MXjZYM9039u3jnrd30bd/L37z1bK585YJml2RmdkS1XLhveGYnV938Izra2/jau1/NqxYf1+ySzMyOuJYK990Dw7z/6w8xs7Odu97zWhYed1SzSzIza4qWCvfP3tPDpud2c/u7z3Owm9mLWstcsrB3+x6+9P+e5PdfMZ/XnHp8s8sxM2uqlgn3v/7XjQj40CWnNbsUM7Oma4lwX9+7gzWPPMPVr1/KycfMbHY5ZmZN1xLh/vf/tomjZ3TwJxec2uxSzMymhOTDvXf7Hr796BauevUpzJ7eUseHzczG
Lflw/+f1WygF/PF5i5pdipnZlJF8uP/wF1v5jRNms+BYn/poZjYq6XAfGinR/cttnLfU30I1MyuXdLg/2reT3YMjvGbpvGaXYmY2pSQd7g89vQOAVy05tsmVmJlNLUmH+y+37mbOjA66Zvv2eGZm5RIP9z0smTcLyddpNzMrl3a4P7ebRcfPanYZZmZTTl3hLulSSRsl9Ui6tsL8RZK+J2m9pHslTfrdMQaHS/Ru38Pi430KpJnZWDXDXVI7cBNwGbAcuErS8jHN/gdwW0ScBdwI/PdGFzpW7/Y9lAL33M3MKqin534u0BMRmyJiEFgNrBzTZjlwd/76ngrzG+6prXsAWDLPPXczs7HqCff5wOay9735tHKPAL+fv/49YI6kQy6qLulqSd2Suvv7+8dT736bt2fhvtDfTDUzO0SjDqh+CLhA0kPABUAfMDK2UUTcHBErImJFV1fXhFb4/N4hAI6eOW1CyzEza0X1XEaxD1hY9n5BPm2/iHiGvOcuaTZwZUTsaFSRlewaGGFau5jekfQJP2Zmk6KeZHwAWCZpiaROYBWwpryBpHmSRpd1HXBLY8s81O6BYWZN7/A57mZmFdQM94gYBq4B1gGPA3dGxAZJN0q6Im92IbBR0hPAicDHJ6ne/XYPDDOr09dvNzOrpK50jIi1wNox064ve30XcFdjSyu2a2DYN+cwM6si2QHr3YPDzJre3uwyzMympGTDfde+YWbP8JkyZmaVpBvuA8PMds/dzKyiZMN998CID6iamVWRcLhnp0Kamdmhkgz3iGD34DBzZjjczcwqSTLc9w6NUArcczczqyLJcN+1bxhwuJuZVZNmuA9k4e6zZczMKksy3HcPZBec9NkyZmaVJRnu+3vuPqBqZlZRkuG+e/+wjMPdzKySNMN90AdUzcyKJBnuL+xzz93MrEiS4b5vKDugOmOaz5YxM6skyXCPyJ7b23wXJjOzSpIM91Ke7o52M7PKEg337LnN9081M6soyXAP8p67s93MrKI0w909dzOzQkmGeykfl/HxVDOzytIMd/fczcwKJRruHnM3MyuSZLjnHXfkdDczqyjNcI/weLuZWYEkw70U4fF2M7MCdYW7pEslbZTUI+naCvNPkXSPpIckrZd0eeNLPaAUPphqZlakZrhLagduAi4DlgNXSVo+ptnHgDsj4hXAKuBzjS60XCnC1x4wMytQT8/9XKAnIjZFxCCwGlg5pk0AR+ev5wLPNK7EQ0X4HHczsyL1hPt8YHPZ+958WrkbgD+S1AusBd5faUGSrpbULam7v79/HOVmwmPuZmaFGnVA9SrgyxGxALgc+IqkQ5YdETdHxIqIWNHV1TXulXnM3cysWD3h3gcsLHu/IJ9W7p3AnQAR8UNgBjCvEQVWUorwF5jMzArUE+4PAMskLZHUSXbAdM2YNk8DFwFIOoMs3Mc/7lJDuOduZlaoZrhHxDBwDbAOeJzsrJgNkm6UdEXe7IPAuyU9AtwOvD1i9NqNjeeeu5lZsbruMB0Ra8kOlJZPu77s9WPA6xpbWlE97rmbmRVJ+Buqza7CzGzqSjTcfdEwM7MiSYa7LxxmZlYsyXAvRSBff8DMrKokw92XHzAzK5ZkuHvM3cysWJLhHhG0JVm5mdmRkWRE+mYdZmbFEg13X87dzKxIouHunruZWZEkwz3A15YxMyuQZri7525mVijJcC+VfOEwM7MiaYa7L/lrZlYo0XD3l5jMzIokGe7gC4eZmRVJMtx9g2wzs2KJhrt77mZmRRINd4+5m5kVSTLcw2fLmJkVSjTcPeZuZlYkyXD3mLuZWbFkw91j7mZm1SUa7r7NnplZkSTD3RcOMzMrVle4S7pU0kZJPZKurTD/U5Iezh9PSNrR+FIPyE6FnMw1mJmlraNWA0ntwE3AxUAv8ICkNRHx2GibiPhAWfv3A6+YhFr3c8/dzKxYPT33c4GeiNgUEYPAamBlQfurgNsbUVw1/hKTmVmxesJ9PrC57H1vPu0Q
khYBS4C7q8y/WlK3pO7+/v7DrXW/8KmQZmaFGn1AdRVwV0SMVJoZETdHxIqIWNHV1TXulfjCYWZmxeoJ9z5gYdn7Bfm0SlYxyUMykJ/nPtkrMTNLWD3h/gCwTNISSZ1kAb5mbCNJpwPHAj9sbImHCo+5m5kVqhnuETEMXAOsAx4H7oyIDZJulHRFWdNVwOqIiMkp9QBffsDMrFjNUyEBImItsHbMtOvHvL+hcWXVqsdj7mZmRZL8hmopgrYkKzczOzKSjMjsgKp77mZm1SQZ7uHLD5iZFUoz3PGYu5lZkSTD3WfLmJkVSzjcne5mZtWkGe4lf4nJzKxIkuEeET6gamZWIM1wx7fZMzMrkmS4e8zdzKxYouHuMXczsyJJhrtv1mFmVizJcPcNss3MiiUZ7r5BtplZsSTD3bfZMzMrlmi4+zx3M7MiSYa7b9ZhZlYsyXD3DbLNzIolG+5tPhfSzKyqJMPdN+swMyuWbLh7zN3MrLokw9036zAzK5ZwuDvdzcyqSTTc8dkyZmYFkgv3iAB8VUgzsyJ1hbukSyVtlNQj6doqbf5A0mOSNkj6emPLPCDPdg/LmJkV6KjVQFI7cBNwMdALPCBpTUQ8VtZmGXAd8LqI2C7phMkquJSnuw+omplVV0/P/VygJyI2RcQgsBpYOabNu4GbImI7QEQ829gyDyiN9tyd7mZmVdUT7vOBzWXve/Np5V4KvFTS9yX9SNKllRYk6WpJ3ZK6+/v7x1XwaM/dzMyqa9QB1Q5gGXAhcBXwRUnHjG0UETdHxIqIWNHV1TWhFXrM3cysunrCvQ9YWPZ+QT6tXC+wJiKGIuJJ4AmysG84j7mbmdVWT7g/ACyTtERSJ7AKWDOmzT+R9dqRNI9smGZTA+vcr+SzZczMaqoZ7hExDFwDrAMeB+6MiA2SbpR0Rd5sHbBV0mPAPcCHI2LrZBRc2n+e+2Qs3cysNdQ8FRIgItYCa8dMu77sdQD/MX9Mqihlz+65m5lVl9w3VN1zNzOrLblwHz0R0j13M7Pqkgt3ny1jZlZbsuHuC4eZmVWXXLj7wmFmZrUlF+4+oGpmVlty4X6g597cOszMprLkwt1j7mZmtSUX7h5zNzOrLblw96mQZma1JRju2bM77mZm1SUY7qM9d6e7mVk1yYV77O+5O9zNzKpJMNw95m5mVkty4e6bdZiZ1ZZguLvnbmZWS7LhDk53M7Nqkgt3X37AzKy2hMPd6W5mVk1y4b5/zD25ys3MjpzkItIXDjMzqy3BcM+eHe1mZtUlF+6jt8j2mLuZWXXJhbu/xGRmVlt64V7yl5jMzGqpK9wlXSppo6QeSddWmP92Sf2SHs4f72p8qZmSLxxmZlZTR60GktqBm4CLgV7gAUlrIuKxMU3viIhrJqHGg4RvkG1mVlM9PfdzgZ6I2BQRg8BqYOXkllWdx9zNzGqrJ9znA5vL3vfm08a6UtJ6SXdJWlhpQZKultQtqbu/v38c5ULgMXczs1oadUD1/wCLI+Is4DvArZUaRcTNEbEiIlZ0dXWNa0Ueczczq62ecO8DynviC/Jp+0XE1ogYyN/+A/DKxpR3KF/y18ystnrC/QFgmaQlkjqBVcCa8gaSTip7ewXweONKPFj4HqpmZjXVPFsmIoYlXQOsA9qBWyJig6Qbge6IWAP8maQrgGFgG/D2ySq4VMqene1mZtXVDHeAiFgLrB0z7fqy19cB1zW2tCq15M/uuZuZVZfeN1R9nruZWU3JhbvH3M3Maksu3P0lJjOz2hIMdw/LmJnVkly4+wbZZma1JRfuvs2emVltyYV7eMzdzKym5MLdlx8wM6stwXDPnt1zNzOrLsFwj9qNzMxe5JIL99HrD7R5XMbMrKrkwt1j7mZmtSUY7tmzx9zNzKpLMNz9DVUzs1qSC/fRC4cJp7uZWTXphXv+7DF3M7Pqkgv3UsmX/DUzqyW9cPcBVTOzmhIM93zMPbnKzcyOnOQicvQLqu63m5lVl1644zF3M7Na
kgv3JfNm88YzT6Ldp8uYmVXV0ewCDtfFy0/k4uUnNrsMM7MpLbmeu5mZ1eZwNzNrQXWFu6RLJW2U1CPp2oJ2V0oKSSsaV6KZmR2umuEuqR24CbgMWA5cJWl5hXZzgD8HftzoIs3M7PDU03M/F+iJiE0RMQisBlZWaPdfgU8C+xpYn5mZjUM94T4f2Fz2vjeftp+kc4CFEfEvDazNzMzGacIHVCW1AX8LfLCOtldL6pbU3d/fP9FVm5lZFfWEex+wsOz9gnzaqDnAbwL3SvolcB6wptJB1Yi4OSJWRMSKrq6u8VdtZmaFNHrzi6oNpA7gCeAislB/APjDiNhQpf29wIciorvGcvuBp8ZRM8A84Llxfnaq8bZMTd6WqcnbAosiombvuOY3VCNiWNI1wDqgHbglIjZIuhHojog14yiOeoqrRlJ3RLTE6ZbelqnJ2zI1eVvqV9flByJiLbB2zLTrq7S9cOJlmZnZRPgbqmZmLSjVcL+52QU0kLdlavK2TE3eljrVPKBqZmbpSbXnbmZmBRzuZmYtKLlwr/cKlVOVpF9K+qmkhyV159OOk/QdST/Pn49tdp2VSLpF0rOSHi2bVrF2ZT6T76f1+SUqpowq23KDpL583zws6fKyedfl27JR0iXNqfpQkhZKukfSY5I2SPrzfHpy+6VgW1LcLzMk3S/pkXxb/iqfvkTSj/Oa75DUmU+fnr/vyecvnnAREZHMg+w8+18AS4FO4BFgebPrOsxt+CUwb8y0vwauzV9fC3yy2XVWqf31wDnAo7VqBy4Hvk12L/PzgB83u/46tuUGsi/gjW27PP+3Nh1Ykv8bbG/2NuS1nQSck7+eQ/aFw+Up7peCbUlxvwiYnb+eRna13POAO4FV+fQvAH+av34v8IX89SrgjonWkFrPvd4rVKZmJXBr/vpW4M1NrKWqiLgP2DZmcrXaVwK3ReZHwDGSTjoyldZWZVuqWQmsjoiBiHgS6CH7t9h0EbElIn6Sv34BeJzswn7J7ZeCbalmKu+XiIhd+dtp+SOA3wbuyqeP3S+j++su4CJJE7pRdGrhXvMKlQkI4P9KelDS1fm0EyNiS/76V0BKN4mtVnuq++qafLjilrLhsSS2Jf9T/hVkvcSk98uYbYEE94ukdkkPA88C3yH7y2JHRAznTcrr3b8t+fydwPETWX9q4d4Kzo+Ic8hufvI+Sa8vnxnZ32VJnp+acu25zwOnAi8HtgB/09xy6idpNvC/gb+IiOfL56W2XypsS5L7JSJGIuLlZBdbPBc4/UiuP7Vwr3WFyikvIvry52eBb5Lt9F+P/mmcPz/bvAoPW7Xak9tXEfHr/D9kCfgiB/7En9LbImkaWRh+LSK+kU9Ocr9U2pZU98uoiNgB3AO8hmwYbPSyL+X17t+WfP5cYOtE1ptauD8ALMuPOHeSHXgY14XLmkHSLGW3I0TSLOB3gUfJtuFtebO3Ad9qToXjUq32NcC/z8/OOA/YWTZMMCWNGXv+PbJ9A9m2rMrPaFgCLAPuP9L1VZKPy34JeDwi/rZsVnL7pdq2JLpfuiQdk7+eCVxMdgzhHuAtebOx+2V0f70FuDv/i2v8mn1UeRxHoS8nO4r+C+Cjza7nMGtfSnZ0/xFgw2j9ZGNr3wN+DnwXOK7ZtVap/3ayP4uHyMYL31mtdrKzBW7K99NPgRXNrr+ObflKXuv6/D/bSWXtP5pvy0bgsmbXX1bX+WRDLuuBh/PH5Snul4JtSXG/nAU8lNf8KHB9Pn0p2S+gHuAfgen59Bn5+558/tKJ1uDLD5iZtaDUhmXMzKwODnczsxbkcDcza0EOdzOzFuRwNzNrQQ53M7MW5HA3M2tB/x/NWS/P40mDvQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f1b33dccf60>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot how training progressed\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "# no space after the percent sign: newer IPython rejects the spaced form\n",
    "%matplotlib inline\n",
    "# Loss recorded for each run of train()\n",
    "lx = np.arange(len(loss))\n",
    "ly = np.array(loss)\n",
    "plt.title('loss of training')\n",
    "plt.xlabel('training run')\n",
    "plt.ylabel('average loss')\n",
    "plt.plot(lx, ly)\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# Fraction of correct predictions recorded for each run of train()\n",
    "cx = np.arange(len(correct))\n",
    "cy = np.array(correct)\n",
    "plt.title('correct rate of training')\n",
    "plt.xlabel('training run')\n",
    "plt.ylabel('fraction correct')\n",
    "plt.plot(cx, cy)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "write done.\n"
     ]
    }
   ],
   "source": [
    "# Load the test data, predict a label for every row and write the result file\n",
    "\n",
    "# the csv module expects files it parses to be opened with newline=''\n",
    "with open('./data/test.csv', newline='') as f :\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the header row\n",
    "    test = [[float(value) for value in row] for row in reader]\n",
    "test = torch.FloatTensor(test)\n",
    "# the input has to live on the same device as the model parameters\n",
    "test_input = test.cuda() if next(net.parameters()).is_cuda else test\n",
    "net.eval()\n",
    "if hasattr(torch, 'no_grad') :\n",
    "    # newer PyTorch: volatile no longer has any effect, no_grad() replaces it\n",
    "    with torch.no_grad() :\n",
    "        predict = net(Variable(test_input))\n",
    "else :\n",
    "    # volatile = True: the forward pass keeps no buffers for a backward pass\n",
    "    predict = net(Variable(test_input, volatile=True))\n",
    "_, predict = torch.max(predict, 1)\n",
    "# bring the result back to the CPU before converting to numpy\n",
    "predict = predict.data.cpu().numpy()\n",
    "\n",
    "# newline='' keeps the csv writer from emitting blank rows on Windows\n",
    "with open('./data/predict.csv', 'w', newline='') as f :\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow(['ImageId', 'Label'])\n",
    "    # ids written to the file start at 1\n",
    "    for image_id, digit in enumerate(predict, 1) :\n",
    "        writer.writerow([image_id, digit])\n",
    "    print('write done.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
